Лабораторная работа¶

Naumov Anton (Any0019)

To contact me in telegram: @any0019

В данном задании вашей задачей будет построить и обучить модель CycleGAN для задачи Img2Img.


Про задачу Img2Img:

  • У вас есть 2 множества изображений ($A$ и $B$)
  • Ваша задача - научиться превращать изображения из множества $A$ в изображения из множества $B$ и наоборот
  • Бывают paired и unpaired множества
  • В случае paired множеств - для каждого изображения из $A$ существует конкретное изображение из $B$, в которое оно должно перейти (и наоборот). Примером такой пары множеств может быть (а) спутниковая фотграфия местности vs схематическая карта местности , (б) фотография фасада здания vs схема фасада здания , ...
  • В случае unpaired множеств - между множествами нет конкретных пар. Примером такой пары множеств может быть (а) фотографии с лошадьми vs фотографии с зебрами , (б) фотографии vs рисунки конкретного художника , ...

Про модель CycleGAN:

  • Модель придумана для задачи Img2Img (в первую очередь unpaired версии) и основана на концепции GAN-ов
  • В модели есть 2 генератора: $G_{A -> B}$ (задача которого - принимая на вход изображение из множества $A$ переводить его в изображение из множества $B$) и $G_{B -> A}$ (аналогично в другую сторону)
  • Надо отметить что на вход генераторы не получают никакую дополнительную случайность, что делает обучение таких моделей проще
  • В модели так же есть 2 дискриминатора: $D_{A}$ (задача которого - отличать реальные изображения из $A$ от сгенерированных с помощью $G_{B -> A}$ изображений) и $D_{B}$ (аналогично для $B$)
  • Функцией ошибки будет, как и для GAN-ов, minmax игра между генераторами и дискриминаторами, но к ошибке генераторов добавится ещё и, так называемый, cycle consistency loss, который проверяет, что после двойного перехода изображения не меняются $G_{B -> A} \big( G_{A -> B} ( a ) \big) = a ; \forall a \in A$ и $G_{A -> B} \big( G_{B -> A} ( b ) \big) = b ; \forall b \in B$
  • Так же можно добавлять или нет штраф за то, чтобы $G_{B -> A}$ не меняла изображения из $A$ и наоборот $G_{A -> B}$ не меняла изображения из $B$
  • Оригинальная статья: https://arxiv.org/pdf/1703.10593.pdf

Ваша задача - ноутбук разбит на несколько частей, каждая со своими баллами

  1. Подготовка данных (2 балла) --> требуется выбрать датасет для обучения (дано несколько на выбор) и составить пайплайн подготовки данных
  2. Составление модели (6 баллов) --> требуется собрать нейросеть (3 балла) и создать функции ошибки (3 балла)
  3. Подготовка обучения (2 балла) --> требуется написать шаги обучения и валидации, визуализацию, а так же полный цикл обучения
  4. Обучение (2 балла) --> требуется обучить модель
  5. Сбор своего датасета и обучение модели на нём (3 балла) --> требуется собрать свой датасет и обучить на нём модель

За наиболее интересные и качественные решения в пункте 5 так же предусмотрены дополнительные баллы

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader

from torchvision import datasets, transforms as tr

import numpy as np

import matplotlib.pyplot as plt
%matplotlib inline

# Для тёмной темы jupyter - уберите если не нужно
plt.style.use('dark_background')

from tqdm.auto import tqdm, trange

import os
import sys
import requests
import cv2

from dataclasses import dataclass

1. Подготовка данных (2 балла)¶

1.1 Выбор и скачивание датасета¶

In [2]:
dataset_folder = "data"
num_images_per_split = 5

os.makedirs(dataset_folder, exist_ok=True)

# В этом цикле каждый из предложенных датасетов
#  - скачивается
#  - распаковывается
#  - отрисовывается + meta
#  - удаляется

# Предлагается посмотреть на все предложенные варианты датасетов и затем оставить один,
#  с которым захочется работать больше всего - закомментируйте (или удалите) все, кроме
#  выбранного, а так же закомментируйте строчки с удалением скачанного датасета
for dataset_name in [
    # Unpaired
    #"apple2orange",
    #"summer2winter_yosemite",
    #"horse2zebra",
    #"monet2photo",
    #"cezanne2photo",
    #"ukiyoe2photo",
    #"vangogh2photo",
    # Paired
    #"maps",
    "facades",
]:
    print(f"Dataset '{dataset_name}'")
    url = f"http://efrosgans.eecs.berkeley.edu/cyclegan/datasets/{dataset_name}.zip"
    download_path = os.path.join(dataset_folder, f"{dataset_name}.zip")
    target_folder = os.path.join(dataset_folder, dataset_name)
    
    print("Loading zip file...", end="")
    # Проверяем что нет такого загруженого файла
    if not os.path.isfile(download_path) or os.path.exists(target_folder):
        # # Можно загрузить через requests библиотеку
        # response = requests.get(url)
        # open(download_path, "wb").write(response.content)
        
        # Можно загрузить через wget
        !wget ${url} -O ${download_path}
    print(" --> done!")
    
    print("Unziping...", end="")
    # Распаковываем
    if os.path.exists(target_folder):
        !rm -r "{target_folder}"
    !unzip -qq "{download_path}" -d "{dataset_folder}"
    print(" --> done!")
    
    # Удаляем zip-файл
    !rm ${download_path}
    
    # Meta + Отрисовка
    print(f"Provided splits: {os.listdir(target_folder)}")
    # Удобный способ быстро получить датасет картинок с лэйблами, если картинки разложены по папкам в формате:
    # root_folder/label_name/img.jpg
    # (в нашем случае нет лэйблов, но картинки разложены в таком же формате, просто вместо label_name идёт
    #  split_name: trainA/testA/trainB/testB)
    dataset = datasets.ImageFolder(target_folder)

    inds_to_show = {i: [] for i, _ in enumerate(dataset.classes)}
    classes_full = 0
    for dataset_ind in range(len(dataset)):
        _, split_ind = dataset[dataset_ind]
        if len(inds_to_show[split_ind]) == num_images_per_split:
            continue
        inds_to_show[split_ind].append(dataset_ind)
        if len(inds_to_show[split_ind]) == num_images_per_split:
            classes_full += 1
        if classes_full == len(dataset.classes):
            break

    for split_name in sorted(dataset.classes):
        split_ind = dataset.class_to_idx[split_name]
        print(f"Split '{split_name}' of dataset '{dataset_name}'", end="")
        split_folder = os.path.join(target_folder, split_name)
        print(f" --> size: {len(os.listdir(split_folder))}")

        plt.subplots(1, num_images_per_split, figsize=(5 * num_images_per_split, 5))
        plt.suptitle(f"{dataset_name} ~ {split_name}", y=0.95)
        for i, dataset_ind in enumerate(inds_to_show[split_ind]):
            plt.subplot(1, num_images_per_split, i + 1)
            plt.imshow(dataset[dataset_ind][0])
            plt.xticks([])
            plt.yticks([])
        plt.show()
    
    # Удаляем скачанный датасет
    #!rm -r "{target_folder}"
    
    print("\n----------------------------\n")
Dataset 'facades'
Loading zip file.../facades.zip: Отказано в доступе
 --> done!
Unziping... --> done!
rm: невозможно удалить '/facades.zip': Нет такого файла или каталога
Provided splits: ['trainA', 'testA', 'testB', 'trainB']
Split 'testA' of dataset 'facades' --> size: 106
No description has been provided for this image
Split 'testB' of dataset 'facades' --> size: 106
No description has been provided for this image
Split 'trainA' of dataset 'facades' --> size: 400
No description has been provided for this image
Split 'trainB' of dataset 'facades' --> size: 400
No description has been provided for this image
----------------------------

1.2 Dataset и Transforms¶

В папке target_folder находятся несколько папок со сплитами для соответствующего датасета, в каждой папке сплита находятся сами .jpg изображения.

Давайте составим их в удобном для нас виде в отдельные датасеты для каждого сплита без лэйблов.

Для получения картинок можно использовать

cv2.imread(img_path)[:, :, ::-1]  # каналы записаны в обратном порядке
In [3]:
# Выбранный выше и скачанный датасет
dataset_folder = "data"
dataset_name = "facades"
target_folder = os.path.join(dataset_folder, dataset_name)

# Класс для датасета изображений без лэйблов с применением трансформов
class ImageDatasetNoLabel(Dataset):
    def __init__(self, data_dir, transforms=None):
        super(ImageDatasetNoLabel).__init__()
        self.data_path = data_dir
        self.transform = transforms
        self.data = []

        for path, dir_name, file_name in os.walk(data_dir):
            for name in file_name:
                raw_img = cv2.imread(f'{path}/{name}')[:, :, ::-1]
                self.data.append(raw_img) 

    def __getitem__(self, index):
        if self.transform is not None:
            return self.transform(self.data[index])
        else:
            return self.data[index]
    
    def __len__(self):
        return len(self.data)

# Удобный класс для хранения всех наших датасетов
@dataclass
class DatasetsClass:
    train_a: ImageDatasetNoLabel
    train_b: ImageDatasetNoLabel
    test_a: ImageDatasetNoLabel
    test_b: ImageDatasetNoLabel


# Все датасеты без трансформов - чтобы посчитать статистики
ds = DatasetsClass(
    train_a=ImageDatasetNoLabel(os.path.join(target_folder, "trainA")),
    train_b=ImageDatasetNoLabel(os.path.join(target_folder, "trainB")),
    test_a=ImageDatasetNoLabel(os.path.join(target_folder, "testA")),
    test_b=ImageDatasetNoLabel(os.path.join(target_folder, "testB")),
)
In [4]:
plt.imshow(ds.train_a[0])
Out[4]:
<matplotlib.image.AxesImage at 0x7607847dc5e0>
No description has been provided for this image
In [5]:
def get_channel_statistics(dataset):
    """
    Функция для получения поканальных статистик (среднее и отклонение) по датасету
    """
    data = np.array(dataset)
    channel_mean = data.mean(axis=(0, 1, 2)) / 255
    channel_std = data.std(axis=(0, 1, 2)) / 255
    return channel_mean, channel_std

# Поканальное среднее и отклонение для A
channel_mean_a, channel_std_a = get_channel_statistics(ds.train_a)
print(channel_mean_a, channel_std_a)

# Поканальное среднее и отклонение для B
channel_mean_b, channel_std_b = get_channel_statistics(ds.train_b)
print(channel_mean_b, channel_std_b)
[0.47777248 0.45261024 0.41679276] [0.24359282 0.23623622 0.23668101]
[0.22201852 0.29931295 0.74451913] [0.34625955 0.28596996 0.33676142]
In [6]:
# Функция для получения train и val transform-ов, а так же функции для де-нормализации изображения
def get_transforms(mean, std):
    train_transform = tr.Compose([
        tr.ToPILImage(),
        tr.Resize((256, 256)),
        tr.ToTensor(),
        tr.ColorJitter(brightness=0.2, saturation=0.2, hue=0.1),
        tr.Normalize(mean=mean, std=std)
    ])
    
    val_transform = tr.Compose([
        tr.ToPILImage(),
        tr.Resize((256, 256)),
        tr.ToTensor(),
        tr.Normalize(mean=mean, std=std)
    ])
    
    def de_normalize(img):
        img = img.clone().cpu()
        for im, m, s in zip(img, mean, std):
            im.mul_(s).add_(m)
        img = img.numpy().transpose(1, 2, 0)
        img = np.clip(img * 255, 0, 255).astype(np.uint8)
        return img
    
    return train_transform, val_transform, de_normalize


# Ваши гиперпараметры
hyperparams = dict()

# transform-ы для A и B
train_transform_a, val_transform_a, de_normalize_a = get_transforms(channel_mean_a, channel_std_a, **hyperparams)
train_transform_b, val_transform_b, de_normalize_b = get_transforms(channel_mean_b, channel_std_b, **hyperparams)


# Функция для визуализации transform-ов
def show_examples(dataset, transform, de_norm, num_per_image=3, image_index=0, title=""):
    fig, ax = plt.subplots(1, 1 + num_per_image, figsize=(5 * (1 + num_per_image), 5))
    
    image = dataset[image_index]
    
    plt.suptitle(title, y=0.95)

    plt.subplot(1, 1 + num_per_image, 1)
    plt.imshow(image)
    plt.title("original")

    for i in range(num_per_image):
        plt.subplot(1, 1 + num_per_image, i + 2)
        plt.title(f"#{i}")
        plt.imshow(de_norm(transform(image)))
    plt.show()

# Проверка на адекватность
show_examples(ds.train_a, train_transform_a, de_normalize_a, num_per_image=4, image_index=0, title="A #0")
show_examples(ds.train_a, val_transform_a, de_normalize_a, num_per_image=1, image_index=0, title="A #0 - val")
show_examples(ds.train_a, train_transform_a, de_normalize_a, num_per_image=4, image_index=1, title="A #1")
show_examples(ds.train_a, train_transform_a, de_normalize_a, num_per_image=4, image_index=2, title="A #2")
show_examples(ds.train_b, train_transform_b, de_normalize_b, num_per_image=4, image_index=0, title="B #0")
show_examples(ds.train_b, val_transform_b, de_normalize_b, num_per_image=1, image_index=0, title="B #0 - val")
show_examples(ds.train_b, train_transform_b, de_normalize_b, num_per_image=4, image_index=1, title="B #1")
show_examples(ds.train_b, train_transform_b, de_normalize_b, num_per_image=4, image_index=2, title="B #2")
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
No description has been provided for this image
In [7]:
# Все датасеты с трансформами
ds = DatasetsClass(
    train_a=ImageDatasetNoLabel(
        os.path.join(target_folder, "trainA"),
        transforms=train_transform_a,
    ),
    train_b=ImageDatasetNoLabel(
        os.path.join(target_folder, "trainB"),
        transforms=val_transform_b,
    ),
    test_a=ImageDatasetNoLabel(
        os.path.join(target_folder, "testA"),
        transforms=val_transform_a,
    ),
    test_b=ImageDatasetNoLabel(
        os.path.join(target_folder, "testB"),
        transforms=val_transform_b,
    ),
)

1.3 DataLoader¶

In [8]:
@dataclass
class DataLoadersClass:
    train_a: DataLoader
    train_b: DataLoader
    test_a: DataLoader
    test_b: DataLoader

batch_size = 50

dataloaders = DataLoadersClass(
    train_a=DataLoader(
        dataset=ds.train_a,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
    ),
    train_b=DataLoader(
        dataset=ds.train_b,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
    ),
    test_a=DataLoader(
        dataset=ds.test_a,
        batch_size=batch_size,
        shuffle=False,
        drop_last=True,
    ),
    test_b=DataLoader(
        dataset=ds.test_b,
        batch_size=batch_size,
        shuffle=False,
        drop_last=True,
    ),
)

2. Модель (6 баллов)¶

2.0 Любые вспомогательные модули и классы¶

In [9]:
import segmentation_models_pytorch as smp

2.1 Архитектура сети (3 балла)¶

In [10]:
class CycleGAN(nn.Module):
    def __init__(self):
        super(CycleGAN, self).__init__()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.generator_A2B = smp.Unet(
                                    encoder_name="resnet18",      
                                    encoder_weights='imagenet',
                                    in_channels=3, 
                                    classes=3,
                                    activation='tanh' 
                                ).to(self.device)

        self.generator_B2A = smp.Unet(
                                    encoder_name="resnet18",      
                                    encoder_weights='imagenet',
                                    in_channels=3, 
                                    classes=3,
                                    activation='tanh' 
                                ).to(self.device)

        self.discriminator_A = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 1, kernel_size=4, stride=1, padding=1)
        ).to(self.device)

        self.discriminator_B = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 1, kernel_size=4, stride=1, padding=1)
        ).to(self.device)

    def forward(self, a, b):
        same_b = self.generator_A2B(b)
        same_a = self.generator_B2A(a)
        
        # generate fake images
        fake_b = self.generator_A2B(a)
        fake_a = self.generator_B2A(b)
                
        # generate recovered images (from fakes)
        recovered_a = self.generator_B2A(fake_b)
        recovered_b = self.generator_A2B(fake_a)
        
        # fake predict
        is_fake_a_pred = F.sigmoid(self.descriminator_a(fake_a))
        is_fake_b_pred = F.sigmoid(self.descriminator_b(fake_b))
        
        # real predict
        is_real_a_pred = F.sigmoid(self.descriminator_a(a))
        is_real_b_pred = F.sigmoid(self.descriminator_b(b))
        
        return same_a, same_b, fake_a, fake_b, recovered_a, recovered_b, \
                    is_fake_a_pred, is_fake_b_pred, is_real_a_pred, is_real_b_pred

2.2 Loss (3 балла)¶

$$ \mathbf{L}_{\text{cyc}} \big( G_{A \rightarrow B}, G_{B \rightarrow A} \big) = \mathbb{E}_{a \sim A} \bigg( \Big\| G_{B \rightarrow A} \big( G_{A \rightarrow B} ( a ) \big) - a \Big\|_1 \bigg) + \mathbb{E}_{b \sim B} \bigg( \Big\| G_{A \rightarrow B} \big( G_{B \rightarrow A} ( b ) \big) - b \Big\|_1 \bigg) \longrightarrow \min_{G}$$

In [11]:
class CycleConsistencyLoss(nn.Module):
    """
    Функция ошибки, проверяющая что после двойного перехода через генераторы изображение не изменилось
    """
    def __init__(self, reduction='mean'):
        super(CycleConsistencyLoss, self).__init__()
        self.loss_fn = nn.L1Loss(reduction=reduction)
    
    def forward(self, x, x_rec):
        # Принимает на вход оригинальное изображение и изображение после двойного перехода
        return self.loss_fn(x_rec, x)

$$ \mathbf{L}_{\text{GAN}} \big( G_{A \rightarrow B}, D_{B} \big) = \mathbb{E}_{b \sim B} \log D_{B} (b) + \mathbb{E}_{a \sim A} \log \Big( 1 - D_{B} \big( G_{A \rightarrow B} (a) \big) \Big) \longrightarrow \min_{G} \max_{D}$$

In [12]:
class AdversarialLossCE(nn.Module):
    """
    Стандартная функция ошибки для minmax игры GAN-ов
    """
    def __init__(self, target_real_label=1.0, target_fake_label=0.0):
        super(AdversarialLossCE, self).__init__()
        self.register_buffer('real_label', torch.tensor(target_real_label))
        self.register_buffer('fake_label', torch.tensor(target_fake_label))
        self.loss_fn = nn.BCEWithLogitsLoss()
    
    def forward(self, real_pred, fake_pred=None):
        if fake_pred is None:
            target = self.real_label.expand_as(real_pred)
            return self.loss_fn(real_pred, target)
        else:
            loss_real = self.loss_fn(real_pred, self.real_label.expand_as(real_pred))
            loss_fake = self.loss_fn(fake_pred, self.fake_label.expand_as(fake_pred))
            return loss_real + loss_fake

$$ \mathbb{E}_{b \sim B} \big( D_{B} (b) - 1 \big)^2 + \mathbb{E}_{a \sim A} \Big( D_{B} \big( G_{A \rightarrow B} (a) \big) \Big)^2 \longrightarrow \min_{D} $$ $$ \mathbb{E}_{a \sim A} \Big( D_{B} \big( G_{A \rightarrow B} (a) \big) - 1 \Big)^2 \longrightarrow \min_{G} $$

In [13]:
class AdversarialLossMSE(nn.Module):
    """
    Можно переписать не через CE, а через MSE loss на одном предсказании, помогает со стабильностью
    """
    def __init__(self):
        super(AdversarialLossMSE, self).__init__()
        self.loss_fn = nn.MSELoss()
    
    def forward(self, real_pred, fake_pred=None):
        # Принимает на вход D_{A}(a) - real_pred и D_{A}(G(b)) - fake_pred или наоборот
        # Может принимать только один аргумент для удобности использования в случае
        #  обучения или генератора, или дискриминатора
        if fake_pred is None:
            target = torch.ones_like(real_pred)
            return self.loss_fn(real_pred, target)
        else:
            loss_real = self.loss_fn(real_pred, torch.ones_like(real_pred))
            loss_fake = self.loss_fn(fake_pred, torch.zeros_like(fake_pred))
            return loss_real + loss_fake

$$ \mathbf{L} \big( G_{A \rightarrow B}, G_{B \rightarrow A}, D_{A}, D_{B} \big) = \mathbf{L}_{\text{GAN}} \big( G_{A \rightarrow B}, D_{B} \big) + \mathbf{L}_{\text{GAN}} \big( G_{B \rightarrow A}, D_{A} \big) + \lambda \cdot \mathbf{L}_{\text{cyc}} \big( G_{A \rightarrow B}, G_{B \rightarrow A} \big) \longrightarrow \min_{G} \max_{D}$$

In [14]:
class FullDiscriminatorLoss(nn.Module):
    """
    Полная ошибка для дискриминатора
    """
    def __init__(self, is_mse=True):
        super(FullDiscriminatorLoss, self).__init__()
        self.adversarial_loss_func = AdversarialLossMSE() if is_mse else AdversarialLossCE()
        self.is_mse = is_mse
    
    def forward(
        self,
        pred_real_a,
        pred_fake_a,
        pred_real_b,
        pred_fake_b
    ):
        loss_d_a = self.adversarial_loss(pred_real_a, pred_fake_a)
        loss_d_b = self.adversarial_loss(pred_real_b, pred_fake_b)
        return loss_d_a + loss_d_b
In [15]:
class FullGeneratorLoss(nn.Module):
    """
    Полная ошибка для генератора
    """
    def __init__(self, lambda_value=10., is_mse=True, lambda_identity=0.5):
        super(FullGeneratorLoss, self).__init__()
        self.adversarial_loss_func = AdversarialLossMSE() if is_mse else AdversarialLossCE()
        self.cycle_consistency_loss_func = CycleConsistencyLoss()
        self.lambda_value = lambda_value
        self.lambda_identity = lambda_identity
    
    def forward(
        self,
        pred_fake_a, pred_fake_b, 
        rec_a, real_a, rec_b, real_b,
        same_a=None, same_b=None
    ):
        loss_gan_a = self.adversarial_loss_func(pred_fake_a)
        loss_gan_b = self.adversarial_loss_func(pred_fake_b)
        
        loss_cycle_a = self.cycle_consistency_loss_func(rec_a, real_a)
        loss_cycle_b = self.cycle_consistency_loss_func(rec_b, real_b)
        
        total_loss = loss_gan_a + loss_gan_b + self.lambda_value * (loss_cycle_a + loss_cycle_b)
        
        if same_a is not None and same_b is not None and self.lambda_identity > 0:
            loss_id_a = self.cycle_consistency_loss_func(same_a, real_a)
            loss_id_b = self.cycle_consistency_loss_func(same_b, real_b)
            total_loss += self.lambda_identity * (loss_id_a + loss_id_b)
            
        return total_loss

3. Подготовка обучения (2 балла)¶

3.1 Шаг обучения дискриминатора¶

In [16]:
def train_discriminators(model, opt_d, loader_a, loader_b, criterion_d):
    model.train()
    losses_tr = []
    
    iter_a = iter(loader_a)
    iter_b = iter(loader_b)
    batches_per_epoch = min(len(iter_a), len(iter_b))
    
    for _ in trange(batches_per_epoch):
        imgs_a = next(iter_a).to(device)
        imgs_b = next(iter_b).to(device)
        
        opt_d.zero_grad()
        
        same_a, same_b, fake_a, fake_b, recovered_a, recovered_b, \
                is_fake_a_pred, is_fake_b_pred, is_real_a_pred, is_real_b_pred = model(imgs_a, imgs_b)
        
        loss = criterion_d(is_real_a_pred, is_fake_a_pred, is_real_b_pred, is_fake_b_pred)
        
        loss.backward()
        opt_d.step()
        losses_tr.append(loss.item())
    
    return model, opt_d, np.mean(losses_tr)

3.2 Шаг обучения генератора¶

In [17]:
def train_generators(model, opt_g, loader_a, loader_b, criterion_g):
    model.train()
    losses_tr = []
    
    iter_a = iter(loader_a)
    iter_b = iter(loader_b)
    batches_per_epoch = min(len(iter_a), len(iter_b))
    
    for _ in trange(batches_per_epoch):
        imgs_a = next(iter_a).to(device)
        imgs_b = next(iter_b).to(device)
        
        opt_g.zero_grad()
        
        same_a, same_b, fake_a, fake_b, recovered_a, recovered_b, \
                is_fake_a_pred, is_fake_b_pred, is_real_a_pred, is_real_b_pred = model(imgs_a, imgs_b)

        loss = criterion_g(fake_a, fake_b, recovered_a, imgs_a, recovered_b, imgs_b, same_a, same_b)

        loss.backward()
        opt_g.step()
        losses_tr.append(loss.item())
    
    return model, opt_g, np.mean(losses_tr)

3.3 Шаг валидации¶

In [18]:
from collections import defaultdict

def val(model, loader_a, loader_b, criterion_d, criterion_g):
    model.eval()
    
    val_data = defaultdict(list)

    with torch.no_grad():
        iter_a = iter(loader_a)
        iter_b = iter(loader_b)
        batches_per_epoch = min(len(iter_a), len(iter_b))

        for _ in trange(batches_per_epoch):
            imgs_a = next(iter_a).to(device)
            imgs_b = next(iter_b).to(device)
            
            same_a, same_b, fake_a, fake_b, recovered_a, recovered_b, \
                is_fake_a_pred, is_fake_b_pred, is_real_a_pred, is_real_b_pred = model(imgs_a, imgs_b)
            
            loss_d = criterion_d(is_real_a_pred, is_fake_a_pred, is_real_b_pred, is_fake_b_pred)
            
            loss_g = criterion_g(fake_a, fake_b, recovered_a, imgs_a, recovered_b, imgs_b, same_a, same_b)
            
            val_data["loss D"].append(loss_d.item())
            val_data["loss G"].append(loss_g.item())
            
            # Оставлю для вас мой кусочек логирования для визуализации, думаю по аналогии
            #  разберётесь что предполагалось в каких переменных
            is_mse_pred = a_real_pred.shape[-1] == 1
            
            if is_mse_pred:
                a_real_pred = a_real_pred[:, 0]
                b_real_pred = b_real_pred[:, 0]
                a_fake_pred = a_fake_pred[:, 0]
                b_fake_pred = b_fake_pred[:, 0]
            else:
                a_real_pred = F.softmax(a_real_pred, dim=1)[:, 1]
                b_real_pred = F.softmax(b_real_pred, dim=1)[:, 1]
                a_fake_pred = F.softmax(a_fake_pred, dim=1)[:, 1]
                b_fake_pred = F.softmax(b_fake_pred, dim=1)[:, 1]
            
            val_data["real pred A"].extend(a_real_pred.cpu().detach().tolist())
            val_data["real pred B"].extend(b_real_pred.cpu().detach().tolist())
            val_data["fake pred A"].extend(a_fake_pred.cpu().detach().tolist())
            val_data["fake pred B"].extend(b_fake_pred.cpu().detach().tolist())
        
        val_data["loss D"] = np.mean(val_data["loss D"])
        val_data["loss G"] = np.mean(val_data["loss G"])
    
    return val_data

3.4 Визуализация сгенерированного¶

In [19]:
def draw_imgs(model, num_images, loader_a, loader_b, de_norm_a, de_norm_b):
    model.eval()
    with torch.no_grad():
        imgs_a = next(iter(loader_a))[:num_images].to(device)
        imgs_b = next(iter(loader_b))[:num_images].to(device)
        
        _, _, fake_a, fake_b, rec_a, rec_b, _, _, _, _ = model(imgs_a, imgs_b)
        
        # Draw num_images examples for A
        fig, ax = plt.subplots(num_images, 3, figsize=(25, 15))
        plt.suptitle("Images from A", y=0.92)
        
        for ind in range(num_images):
            plt.subplot(num_images, 3, ind * 3 + 1)
            plt.title("Original from A")
            plt.imshow(de_norm_a(imgs_a[ind], normalized=True))
            plt.xticks([])
            plt.yticks([])
            
            plt.subplot(num_images, 3, ind * 3 + 2)
            plt.title("Translated to B")
            plt.imshow(de_norm_b(fake_b[ind], normalized=True))
            plt.xticks([])
            plt.yticks([])
            
            plt.subplot(num_images, 3, ind * 3 + 3)
            plt.title("Reconstructed A")
            plt.imshow(de_norm_a(rec_a[ind], normalized=True))
            plt.xticks([])
            plt.yticks([])
        
        # Draw num_images examples for B
        fig, ax = plt.subplots(num_images, 3, figsize=(25, 15))
        plt.suptitle("Images from B", y=0.92)
        
        for ind in range(num_images):
            plt.subplot(num_images, 3, ind * 3 + 1)
            plt.title("Original from B")
            plt.imshow(de_norm_b(imgs_b[ind], normalized=True))
            plt.xticks([])
            plt.yticks([])
            
            plt.subplot(num_images, 3, ind * 3 + 2)
            plt.title("Translated to A")
            plt.imshow(de_norm_a(fake_a[ind], normalized=True))
            plt.xticks([])
            plt.yticks([])
            
            plt.subplot(num_images, 3, ind * 3 + 3)
            plt.title("Reconstructed B")
            plt.imshow(de_norm_b(rec_b[ind], normalized=True))
            plt.xticks([])
            plt.yticks([])
    
        plt.show()

3.5 Цикл обучения¶

In [20]:
from IPython.display import clear_output
import warnings

def get_model_name(chkp_folder, model_name=None):
    # Выбираем имя чекпоинта для сохранения
    if model_name is None:
        if os.path.exists(chkp_folder):
            num_starts = len(os.listdir(chkp_folder)) + 1
        else:
            num_starts = 1
        model_name = f'model#{num_starts}'
    else:
        if "#" not in model_name:
            model_name += "#0"
    changed = False
    while os.path.exists(os.path.join(chkp_folder, model_name + '.pt')):
        model_name, ind = model_name.split("#")
        model_name += f"#{int(ind) + 1}"
        changed=True
    if changed:
        warnings.warn(f"Selected model_name was used already! To avoid possible overwrite - model_name changed to {model_name}")
    return model_name


def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']


def learning_loop(
    model,
    optimizer_g,
    g_iters_per_epoch,
    optimizer_d,
    d_iters_per_epoch,
    train_loader_a,
    train_loader_b,
    val_loader_a,
    val_loader_b,
    criterion_d,
    criterion_g,
    de_norm_a,
    de_norm_b,
    scheduler_d=None,
    scheduler_g=None,
    min_lr=None,
    epochs=10,
    val_every=1,
    draw_every=1,
    model_name=None,
    chkp_folder="./chkps",
    images_per_validation=3,
    plots=None,
    starting_epoch=0,
):
    model_name = get_model_name(chkp_folder, model_name)
    
    if plots is None:
        plots = {
            'train G': [],
            'train D': [],
            'val D': [],
            'val G': [],
            "lr G": [],
            "lr D": [],
            "hist real A": [],
            "hist gen A": [],
            "hist real B": [],
            "hist gen B": [],
        }

    for epoch in np.arange(1, epochs+1) + starting_epoch:
        print(f'#{epoch}/{epochs}:')

        plots['lr G'].append(get_lr(optimizer_g))
        plots['lr D'].append(get_lr(optimizer_d))
        
        # train discriminators
        print(f"train discriminators ({d_iters_per_epoch} times)")
        loss_d = []
        for _ in range(d_iters_per_epoch):
            model, optimizer_d, loss = train_discriminators(model, optimizer_d, train_loader_a, train_loader_b, criterion_d)
            loss_d.append(loss)
        plots['train D'].extend(loss_d)
        
        # train generators
        print(f"train generators ({g_iters_per_epoch} times)")
        loss_g = []
        for _ in range(g_iters_per_epoch):
            model, optimizer_g, loss = train_generators(model, optimizer_g, train_loader_a, train_loader_b, criterion_g)
            loss_g.append(loss)
        plots['train G'].extend(loss_g)

        if not (epoch % val_every):
            print("validate")
            val_data = val(model, val_loader_a, val_loader_b, criterion_d, criterion_g)
            plots['val D'].append(val_data["loss D"])
            plots['val G'].append(val_data["loss G"])
            plots['hist real A'].append(val_data["real pred A"])
            plots['hist gen A'].append(val_data["fake pred A"])
            plots['hist real B'].append(val_data["real pred B"])
            plots['hist gen B'].append(val_data["fake pred B"])
            
            # Сохраняем модель
            if not os.path.exists(chkp_folder):
                os.makedirs(chkp_folder)
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_d_state_dict': optimizer_d.state_dict(),
                    'optimizer_g_state_dict': optimizer_g.state_dict(),
                    'scheduler_d_state_dict': scheduler_d.state_dict(),
                    'scheduler_g_state_dict': scheduler_g.state_dict(),
                    'plots': plots,
                },
                os.path.join(chkp_folder, model_name + '.pt'),
            )
            
            # Шедулинг
            if scheduler_d:
                try:
                    scheduler_d.step()
                except:
                    scheduler_d.step(loss_d)
            if scheduler_g:
                try:
                    scheduler_g.step()
                except:
                    scheduler_g.step(loss_g)

        if not (epoch % draw_every):
            clear_output(True)
            
            hh = 2
            ww = 2
            plt_ind = 1
            fig, ax = plt.subplots(hh, ww, figsize=(25, 12))
            fig.suptitle(f'#{epoch}/{epochs}:')

            plt.subplot(hh, ww, plt_ind)
            plt.title('discriminators losses')
            d_plot_step = 1. / d_iters_per_epoch
            plt.plot(np.arange(d_plot_step, epoch + d_plot_step, d_plot_step), plots['train D'], 'r.-', label='train', alpha=0.7)
            plt.plot(np.arange(1, epoch + 1), plots['val D'], 'g.-', label='val', alpha=0.7)
            plt.grid()
            plt.legend()
            plt_ind += 1
            
            plt.subplot(hh, ww, plt_ind)
            plt.title('generators losses')
            g_plot_step = 1. / g_iters_per_epoch
            plt.plot(np.arange(g_plot_step, epoch + g_plot_step, g_plot_step), plots['train G'], 'r.-', label='train', alpha=0.7)
            plt.plot(np.arange(1, epoch + 1), plots['val G'], 'g.-', label='val', alpha=0.7)
            plt.grid()
            plt.legend()
            plt_ind += 1
            
            # plt.subplot(hh, ww, plt_ind)
            # plt.title('learning rates')
            # plt.plot(plots["lr D"], 'b.-', label='lr discriminator', alpha=0.7)
            # plt.plot(plots["lr G"], 'm.-', label='lr generator', alpha=0.7)
            # plt.legend()
            # plt_ind += 1
            
            plt.subplot(hh, ww, plt_ind)
            plt.title("Discriminator A predictions")
            plt.hist(plots["hist real A"][-1], bins=50, density=True, label="real", color="green", alpha=0.7)
            plt.hist(plots["hist gen A"][-1], bins=50, density=True, label="generated", color="red", alpha=0.7)
            plt.xlim((-0.05, 1.05))
            plt.xticks(ticks=np.arange(0, 1.05, 0.1))
            plt.legend()
            plt_ind += 1
            
            plt.subplot(hh, ww, plt_ind)
            plt.title("Discriminator B predictions")
            plt.hist(plots["hist real B"][-1], bins=50, density=True, label="real", color="green", alpha=0.7)
            plt.hist(plots["hist gen B"][-1], bins=50, density=True, label="generated", color="red", alpha=0.7)
            plt.xlim((-0.05, 1.05))
            plt.xticks(ticks=np.arange(0, 1.05, 0.1))
            plt.legend()
            plt_ind += 1
            
            plt.show()
            
            draw_imgs(model, images_per_validation, val_loader_a, val_loader_b, de_norm_a, de_norm_b)
                
        
        if min_lr and get_lr(optimizer_d) <= min_lr:
            print(f'Learning process ended with early stop for discriminator after epoch {epoch}')
            break
        
        if min_lr and get_lr(optimizer_g) <= min_lr:
            print(f'Learning process ended with early stop for generator after epoch {epoch}')
            break
    
    return model, optimizer_d, optimizer_g, plots

4. Обучение (2 балла)¶

4.1 Инициализация модели и оптимайзера¶

In [21]:
from collections import defaultdict
from termcolor import colored


def beautiful_int(i):
    i = str(i)
    return ".".join(reversed([i[max(j, 0):j+3] for j in range(len(i) - 3, -3, -3)]))


# Подсчёт числа параметров в нашей модели
def model_num_params(model, verbose_all=True, verbose_only_learnable=False):
    sum_params = 0
    sum_learnable_params = 0
    submodules = defaultdict(lambda : [0, 0])
    for name, param in model.named_parameters():
        num_params = param.numel()
        if verbose_all or (verbose_only_learnable and param.requires_grad):
            print(
                colored(
                    '{: <65} ~  {: <9} params ~ grad: {}'.format(
                        name,
                        beautiful_int(num_params),
                        param.requires_grad,
                    ),
                    {True: "green", False: "red"}[param.requires_grad],
                )
            )
        sum_params += num_params
        sm = name.split(".")[0]
        submodules[sm][0] += num_params
        if param.requires_grad:
            sum_learnable_params += num_params
            submodules[sm][1] += num_params
    print(
        f'\nIn total:\n  - {beautiful_int(sum_params)} params\n  - {beautiful_int(sum_learnable_params)} learnable params'
    )
    
    for sm, v in submodules.items():
        print(
            f"\n . {sm}:\n .   - {beautiful_int(submodules[sm][0])} params\n .   - {beautiful_int(submodules[sm][1])} learnable params"
        )
    return sum_params, sum_learnable_params
In [22]:
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

def create_model_and_optimizer(model_class, model_params, lr=1e-3, weight_decay=1e-5, device=device):
    model = model_class(**model_params)
    model = model.to(device)
    
    optimizer_d = torch.optim.Adam(
        list(model.discriminator_A.parameters()) + list(model.discriminator_B.parameters()),
        lr,
        weight_decay=weight_decay,
    )
    optimizer_g = torch.optim.Adam(
        list(model.generator_A2B.parameters()) + list(model.generator_B2A.parameters()),
        lr,
        weight_decay=weight_decay,
    )
    return model, optimizer_d, optimizer_g

4.2 Фактическое обучение¶

In [29]:
device = torch.device('cpu')

results = []

model, optimizer_d, optimizer_g = create_model_and_optimizer(
    model_class = CycleGAN,
    model_params = {},
    lr = 1e-3,
    device = device,
)


scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optimizer_d, gamma=0.9)
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optimizer_g, gamma=0.9)

criterion_d = FullDiscriminatorLoss()
criterion_g = FullGeneratorLoss()

sum_params, sum_learnable_params = model_num_params(model)
generator_A2B.encoder.conv1.weight                                ~  9.408     params ~ grad: True
generator_A2B.encoder.bn1.weight                                  ~  64        params ~ grad: True
generator_A2B.encoder.bn1.bias                                    ~  64        params ~ grad: True
generator_A2B.encoder.layer1.0.conv1.weight                       ~  36.864    params ~ grad: True
generator_A2B.encoder.layer1.0.bn1.weight                         ~  64        params ~ grad: True
generator_A2B.encoder.layer1.0.bn1.bias                           ~  64        params ~ grad: True
generator_A2B.encoder.layer1.0.conv2.weight                       ~  36.864    params ~ grad: True
generator_A2B.encoder.layer1.0.bn2.weight                         ~  64        params ~ grad: True
generator_A2B.encoder.layer1.0.bn2.bias                           ~  64        params ~ grad: True
generator_A2B.encoder.layer1.1.conv1.weight                       ~  36.864    params ~ grad: True
generator_A2B.encoder.layer1.1.bn1.weight                         ~  64        params ~ grad: True
generator_A2B.encoder.layer1.1.bn1.bias                           ~  64        params ~ grad: True
generator_A2B.encoder.layer1.1.conv2.weight                       ~  36.864    params ~ grad: True
generator_A2B.encoder.layer1.1.bn2.weight                         ~  64        params ~ grad: True
generator_A2B.encoder.layer1.1.bn2.bias                           ~  64        params ~ grad: True
generator_A2B.encoder.layer2.0.conv1.weight                       ~  73.728    params ~ grad: True
generator_A2B.encoder.layer2.0.bn1.weight                         ~  128       params ~ grad: True
generator_A2B.encoder.layer2.0.bn1.bias                           ~  128       params ~ grad: True
generator_A2B.encoder.layer2.0.conv2.weight                       ~  147.456   params ~ grad: True
generator_A2B.encoder.layer2.0.bn2.weight                         ~  128       params ~ grad: True
generator_A2B.encoder.layer2.0.bn2.bias                           ~  128       params ~ grad: True
generator_A2B.encoder.layer2.0.downsample.0.weight                ~  8.192     params ~ grad: True
generator_A2B.encoder.layer2.0.downsample.1.weight                ~  128       params ~ grad: True
generator_A2B.encoder.layer2.0.downsample.1.bias                  ~  128       params ~ grad: True
generator_A2B.encoder.layer2.1.conv1.weight                       ~  147.456   params ~ grad: True
generator_A2B.encoder.layer2.1.bn1.weight                         ~  128       params ~ grad: True
generator_A2B.encoder.layer2.1.bn1.bias                           ~  128       params ~ grad: True
generator_A2B.encoder.layer2.1.conv2.weight                       ~  147.456   params ~ grad: True
generator_A2B.encoder.layer2.1.bn2.weight                         ~  128       params ~ grad: True
generator_A2B.encoder.layer2.1.bn2.bias                           ~  128       params ~ grad: True
generator_A2B.encoder.layer3.0.conv1.weight                       ~  294.912   params ~ grad: True
generator_A2B.encoder.layer3.0.bn1.weight                         ~  256       params ~ grad: True
generator_A2B.encoder.layer3.0.bn1.bias                           ~  256       params ~ grad: True
generator_A2B.encoder.layer3.0.conv2.weight                       ~  589.824   params ~ grad: True
generator_A2B.encoder.layer3.0.bn2.weight                         ~  256       params ~ grad: True
generator_A2B.encoder.layer3.0.bn2.bias                           ~  256       params ~ grad: True
generator_A2B.encoder.layer3.0.downsample.0.weight                ~  32.768    params ~ grad: True
generator_A2B.encoder.layer3.0.downsample.1.weight                ~  256       params ~ grad: True
generator_A2B.encoder.layer3.0.downsample.1.bias                  ~  256       params ~ grad: True
generator_A2B.encoder.layer3.1.conv1.weight                       ~  589.824   params ~ grad: True
generator_A2B.encoder.layer3.1.bn1.weight                         ~  256       params ~ grad: True
generator_A2B.encoder.layer3.1.bn1.bias                           ~  256       params ~ grad: True
generator_A2B.encoder.layer3.1.conv2.weight                       ~  589.824   params ~ grad: True
generator_A2B.encoder.layer3.1.bn2.weight                         ~  256       params ~ grad: True
generator_A2B.encoder.layer3.1.bn2.bias                           ~  256       params ~ grad: True
generator_A2B.encoder.layer4.0.conv1.weight                       ~  1.179.648 params ~ grad: True
generator_A2B.encoder.layer4.0.bn1.weight                         ~  512       params ~ grad: True
generator_A2B.encoder.layer4.0.bn1.bias                           ~  512       params ~ grad: True
generator_A2B.encoder.layer4.0.conv2.weight                       ~  2.359.296 params ~ grad: True
generator_A2B.encoder.layer4.0.bn2.weight                         ~  512       params ~ grad: True
generator_A2B.encoder.layer4.0.bn2.bias                           ~  512       params ~ grad: True
generator_A2B.encoder.layer4.0.downsample.0.weight                ~  131.072   params ~ grad: True
generator_A2B.encoder.layer4.0.downsample.1.weight                ~  512       params ~ grad: True
generator_A2B.encoder.layer4.0.downsample.1.bias                  ~  512       params ~ grad: True
generator_A2B.encoder.layer4.1.conv1.weight                       ~  2.359.296 params ~ grad: True
generator_A2B.encoder.layer4.1.bn1.weight                         ~  512       params ~ grad: True
generator_A2B.encoder.layer4.1.bn1.bias                           ~  512       params ~ grad: True
generator_A2B.encoder.layer4.1.conv2.weight                       ~  2.359.296 params ~ grad: True
generator_A2B.encoder.layer4.1.bn2.weight                         ~  512       params ~ grad: True
generator_A2B.encoder.layer4.1.bn2.bias                           ~  512       params ~ grad: True
generator_A2B.decoder.blocks.0.conv1.0.weight                     ~  1.769.472 params ~ grad: True
generator_A2B.decoder.blocks.0.conv1.1.weight                     ~  256       params ~ grad: True
generator_A2B.decoder.blocks.0.conv1.1.bias                       ~  256       params ~ grad: True
generator_A2B.decoder.blocks.0.conv2.0.weight                     ~  589.824   params ~ grad: True
generator_A2B.decoder.blocks.0.conv2.1.weight                     ~  256       params ~ grad: True
generator_A2B.decoder.blocks.0.conv2.1.bias                       ~  256       params ~ grad: True
generator_A2B.decoder.blocks.1.conv1.0.weight                     ~  442.368   params ~ grad: True
generator_A2B.decoder.blocks.1.conv1.1.weight                     ~  128       params ~ grad: True
generator_A2B.decoder.blocks.1.conv1.1.bias                       ~  128       params ~ grad: True
generator_A2B.decoder.blocks.1.conv2.0.weight                     ~  147.456   params ~ grad: True
generator_A2B.decoder.blocks.1.conv2.1.weight                     ~  128       params ~ grad: True
generator_A2B.decoder.blocks.1.conv2.1.bias                       ~  128       params ~ grad: True
generator_A2B.decoder.blocks.2.conv1.0.weight                     ~  110.592   params ~ grad: True
generator_A2B.decoder.blocks.2.conv1.1.weight                     ~  64        params ~ grad: True
generator_A2B.decoder.blocks.2.conv1.1.bias                       ~  64        params ~ grad: True
generator_A2B.decoder.blocks.2.conv2.0.weight                     ~  36.864    params ~ grad: True
generator_A2B.decoder.blocks.2.conv2.1.weight                     ~  64        params ~ grad: True
generator_A2B.decoder.blocks.2.conv2.1.bias                       ~  64        params ~ grad: True
generator_A2B.decoder.blocks.3.conv1.0.weight                     ~  36.864    params ~ grad: True
generator_A2B.decoder.blocks.3.conv1.1.weight                     ~  32        params ~ grad: True
generator_A2B.decoder.blocks.3.conv1.1.bias                       ~  32        params ~ grad: True
generator_A2B.decoder.blocks.3.conv2.0.weight                     ~  9.216     params ~ grad: True
generator_A2B.decoder.blocks.3.conv2.1.weight                     ~  32        params ~ grad: True
generator_A2B.decoder.blocks.3.conv2.1.bias                       ~  32        params ~ grad: True
generator_A2B.decoder.blocks.4.conv1.0.weight                     ~  4.608     params ~ grad: True
generator_A2B.decoder.blocks.4.conv1.1.weight                     ~  16        params ~ grad: True
generator_A2B.decoder.blocks.4.conv1.1.bias                       ~  16        params ~ grad: True
generator_A2B.decoder.blocks.4.conv2.0.weight                     ~  2.304     params ~ grad: True
generator_A2B.decoder.blocks.4.conv2.1.weight                     ~  16        params ~ grad: True
generator_A2B.decoder.blocks.4.conv2.1.bias                       ~  16        params ~ grad: True
generator_A2B.segmentation_head.0.weight                          ~  432       params ~ grad: True
generator_A2B.segmentation_head.0.bias                            ~  3         params ~ grad: True
generator_B2A.encoder.conv1.weight                                ~  9.408     params ~ grad: True
generator_B2A.encoder.bn1.weight                                  ~  64        params ~ grad: True
generator_B2A.encoder.bn1.bias                                    ~  64        params ~ grad: True
generator_B2A.encoder.layer1.0.conv1.weight                       ~  36.864    params ~ grad: True
generator_B2A.encoder.layer1.0.bn1.weight                         ~  64        params ~ grad: True
generator_B2A.encoder.layer1.0.bn1.bias                           ~  64        params ~ grad: True
generator_B2A.encoder.layer1.0.conv2.weight                       ~  36.864    params ~ grad: True
generator_B2A.encoder.layer1.0.bn2.weight                         ~  64        params ~ grad: True
generator_B2A.encoder.layer1.0.bn2.bias                           ~  64        params ~ grad: True
generator_B2A.encoder.layer1.1.conv1.weight                       ~  36.864    params ~ grad: True
generator_B2A.encoder.layer1.1.bn1.weight                         ~  64        params ~ grad: True
generator_B2A.encoder.layer1.1.bn1.bias                           ~  64        params ~ grad: True
generator_B2A.encoder.layer1.1.conv2.weight                       ~  36.864    params ~ grad: True
generator_B2A.encoder.layer1.1.bn2.weight                         ~  64        params ~ grad: True
generator_B2A.encoder.layer1.1.bn2.bias                           ~  64        params ~ grad: True
generator_B2A.encoder.layer2.0.conv1.weight                       ~  73.728    params ~ grad: True
generator_B2A.encoder.layer2.0.bn1.weight                         ~  128       params ~ grad: True
generator_B2A.encoder.layer2.0.bn1.bias                           ~  128       params ~ grad: True
generator_B2A.encoder.layer2.0.conv2.weight                       ~  147.456   params ~ grad: True
generator_B2A.encoder.layer2.0.bn2.weight                         ~  128       params ~ grad: True
generator_B2A.encoder.layer2.0.bn2.bias                           ~  128       params ~ grad: True
generator_B2A.encoder.layer2.0.downsample.0.weight                ~  8.192     params ~ grad: True
generator_B2A.encoder.layer2.0.downsample.1.weight                ~  128       params ~ grad: True
generator_B2A.encoder.layer2.0.downsample.1.bias                  ~  128       params ~ grad: True
generator_B2A.encoder.layer2.1.conv1.weight                       ~  147.456   params ~ grad: True
generator_B2A.encoder.layer2.1.bn1.weight                         ~  128       params ~ grad: True
generator_B2A.encoder.layer2.1.bn1.bias                           ~  128       params ~ grad: True
generator_B2A.encoder.layer2.1.conv2.weight                       ~  147.456   params ~ grad: True
generator_B2A.encoder.layer2.1.bn2.weight                         ~  128       params ~ grad: True
generator_B2A.encoder.layer2.1.bn2.bias                           ~  128       params ~ grad: True
generator_B2A.encoder.layer3.0.conv1.weight                       ~  294.912   params ~ grad: True
generator_B2A.encoder.layer3.0.bn1.weight                         ~  256       params ~ grad: True
generator_B2A.encoder.layer3.0.bn1.bias                           ~  256       params ~ grad: True
generator_B2A.encoder.layer3.0.conv2.weight                       ~  589.824   params ~ grad: True
generator_B2A.encoder.layer3.0.bn2.weight                         ~  256       params ~ grad: True
generator_B2A.encoder.layer3.0.bn2.bias                           ~  256       params ~ grad: True
generator_B2A.encoder.layer3.0.downsample.0.weight                ~  32.768    params ~ grad: True
generator_B2A.encoder.layer3.0.downsample.1.weight                ~  256       params ~ grad: True
generator_B2A.encoder.layer3.0.downsample.1.bias                  ~  256       params ~ grad: True
generator_B2A.encoder.layer3.1.conv1.weight                       ~  589.824   params ~ grad: True
generator_B2A.encoder.layer3.1.bn1.weight                         ~  256       params ~ grad: True
generator_B2A.encoder.layer3.1.bn1.bias                           ~  256       params ~ grad: True
generator_B2A.encoder.layer3.1.conv2.weight                       ~  589.824   params ~ grad: True
generator_B2A.encoder.layer3.1.bn2.weight                         ~  256       params ~ grad: True
generator_B2A.encoder.layer3.1.bn2.bias                           ~  256       params ~ grad: True
generator_B2A.encoder.layer4.0.conv1.weight                       ~  1.179.648 params ~ grad: True
generator_B2A.encoder.layer4.0.bn1.weight                         ~  512       params ~ grad: True
generator_B2A.encoder.layer4.0.bn1.bias                           ~  512       params ~ grad: True
generator_B2A.encoder.layer4.0.conv2.weight                       ~  2.359.296 params ~ grad: True
generator_B2A.encoder.layer4.0.bn2.weight                         ~  512       params ~ grad: True
generator_B2A.encoder.layer4.0.bn2.bias                           ~  512       params ~ grad: True
generator_B2A.encoder.layer4.0.downsample.0.weight                ~  131.072   params ~ grad: True
generator_B2A.encoder.layer4.0.downsample.1.weight                ~  512       params ~ grad: True
generator_B2A.encoder.layer4.0.downsample.1.bias                  ~  512       params ~ grad: True
generator_B2A.encoder.layer4.1.conv1.weight                       ~  2.359.296 params ~ grad: True
generator_B2A.encoder.layer4.1.bn1.weight                         ~  512       params ~ grad: True
generator_B2A.encoder.layer4.1.bn1.bias                           ~  512       params ~ grad: True
generator_B2A.encoder.layer4.1.conv2.weight                       ~  2.359.296 params ~ grad: True
generator_B2A.encoder.layer4.1.bn2.weight                         ~  512       params ~ grad: True
generator_B2A.encoder.layer4.1.bn2.bias                           ~  512       params ~ grad: True
generator_B2A.decoder.blocks.0.conv1.0.weight                     ~  1.769.472 params ~ grad: True
generator_B2A.decoder.blocks.0.conv1.1.weight                     ~  256       params ~ grad: True
generator_B2A.decoder.blocks.0.conv1.1.bias                       ~  256       params ~ grad: True
generator_B2A.decoder.blocks.0.conv2.0.weight                     ~  589.824   params ~ grad: True
generator_B2A.decoder.blocks.0.conv2.1.weight                     ~  256       params ~ grad: True
generator_B2A.decoder.blocks.0.conv2.1.bias                       ~  256       params ~ grad: True
generator_B2A.decoder.blocks.1.conv1.0.weight                     ~  442.368   params ~ grad: True
generator_B2A.decoder.blocks.1.conv1.1.weight                     ~  128       params ~ grad: True
generator_B2A.decoder.blocks.1.conv1.1.bias                       ~  128       params ~ grad: True
generator_B2A.decoder.blocks.1.conv2.0.weight                     ~  147.456   params ~ grad: True
generator_B2A.decoder.blocks.1.conv2.1.weight                     ~  128       params ~ grad: True
generator_B2A.decoder.blocks.1.conv2.1.bias                       ~  128       params ~ grad: True
generator_B2A.decoder.blocks.2.conv1.0.weight                     ~  110.592   params ~ grad: True
generator_B2A.decoder.blocks.2.conv1.1.weight                     ~  64        params ~ grad: True
generator_B2A.decoder.blocks.2.conv1.1.bias                       ~  64        params ~ grad: True
generator_B2A.decoder.blocks.2.conv2.0.weight                     ~  36.864    params ~ grad: True
generator_B2A.decoder.blocks.2.conv2.1.weight                     ~  64        params ~ grad: True
generator_B2A.decoder.blocks.2.conv2.1.bias                       ~  64        params ~ grad: True
generator_B2A.decoder.blocks.3.conv1.0.weight                     ~  36.864    params ~ grad: True
generator_B2A.decoder.blocks.3.conv1.1.weight                     ~  32        params ~ grad: True
generator_B2A.decoder.blocks.3.conv1.1.bias                       ~  32        params ~ grad: True
generator_B2A.decoder.blocks.3.conv2.0.weight                     ~  9.216     params ~ grad: True
generator_B2A.decoder.blocks.3.conv2.1.weight                     ~  32        params ~ grad: True
generator_B2A.decoder.blocks.3.conv2.1.bias                       ~  32        params ~ grad: True
generator_B2A.decoder.blocks.4.conv1.0.weight                     ~  4.608     params ~ grad: True
generator_B2A.decoder.blocks.4.conv1.1.weight                     ~  16        params ~ grad: True
generator_B2A.decoder.blocks.4.conv1.1.bias                       ~  16        params ~ grad: True
generator_B2A.decoder.blocks.4.conv2.0.weight                     ~  2.304     params ~ grad: True
generator_B2A.decoder.blocks.4.conv2.1.weight                     ~  16        params ~ grad: True
generator_B2A.decoder.blocks.4.conv2.1.bias                       ~  16        params ~ grad: True
generator_B2A.segmentation_head.0.weight                          ~  432       params ~ grad: True
generator_B2A.segmentation_head.0.bias                            ~  3         params ~ grad: True
discriminator_A.0.weight                                          ~  3.072     params ~ grad: True
discriminator_A.0.bias                                            ~  64        params ~ grad: True
discriminator_A.2.weight                                          ~  131.072   params ~ grad: True
discriminator_A.2.bias                                            ~  128       params ~ grad: True
discriminator_A.5.weight                                          ~  524.288   params ~ grad: True
discriminator_A.5.bias                                            ~  256       params ~ grad: True
discriminator_A.8.weight                                          ~  4.096     params ~ grad: True
discriminator_A.8.bias                                            ~  1         params ~ grad: True
discriminator_B.0.weight                                          ~  3.072     params ~ grad: True
discriminator_B.0.bias                                            ~  64        params ~ grad: True
discriminator_B.2.weight                                          ~  131.072   params ~ grad: True
discriminator_B.2.bias                                            ~  128       params ~ grad: True
discriminator_B.5.weight                                          ~  524.288   params ~ grad: True
discriminator_B.5.bias                                            ~  256       params ~ grad: True
discriminator_B.8.weight                                          ~  4.096     params ~ grad: True
discriminator_B.8.bias                                            ~  1         params ~ grad: True

In total:
  - 29.982.952 params
  - 29.982.952 learnable params

 . generator_A2B:
 .   - 14.328.499 params
 .   - 14.328.499 learnable params

 . generator_B2A:
 .   - 14.328.499 params
 .   - 14.328.499 learnable params

 . discriminator_A:
 .   - 662.977 params
 .   - 662.977 learnable params

 . discriminator_B:
 .   - 662.977 params
 .   - 662.977 learnable params
In [30]:
model
Out[30]:
CycleGAN(
  (generator_A2B): Unet(
    (encoder): ResNetEncoder(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (layer2): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (downsample): Sequential(
            (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): BasicBlock(
          (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (layer3): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (downsample): Sequential(
            (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): BasicBlock(
          (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (layer4): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (downsample): Sequential(
            (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): BasicBlock(
          (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
    )
    (decoder): UnetDecoder(
      (center): Identity()
      (blocks): ModuleList(
        (0): DecoderBlock(
          (conv1): Conv2dReLU(
            (0): Conv2d(768, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention1): Attention(
            (attention): Identity()
          )
          (conv2): Conv2dReLU(
            (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention2): Attention(
            (attention): Identity()
          )
        )
        (1): DecoderBlock(
          (conv1): Conv2dReLU(
            (0): Conv2d(384, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention1): Attention(
            (attention): Identity()
          )
          (conv2): Conv2dReLU(
            (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention2): Attention(
            (attention): Identity()
          )
        )
        (2): DecoderBlock(
          (conv1): Conv2dReLU(
            (0): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention1): Attention(
            (attention): Identity()
          )
          (conv2): Conv2dReLU(
            (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention2): Attention(
            (attention): Identity()
          )
        )
        (3): DecoderBlock(
          (conv1): Conv2dReLU(
            (0): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention1): Attention(
            (attention): Identity()
          )
          (conv2): Conv2dReLU(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention2): Attention(
            (attention): Identity()
          )
        )
        (4): DecoderBlock(
          (conv1): Conv2dReLU(
            (0): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention1): Attention(
            (attention): Identity()
          )
          (conv2): Conv2dReLU(
            (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention2): Attention(
            (attention): Identity()
          )
        )
      )
    )
    (segmentation_head): SegmentationHead(
      (0): Conv2d(16, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): Identity()
      (2): Activation(
        (activation): Tanh()
      )
    )
  )
  (generator_B2A): Unet(
    (encoder): ResNetEncoder(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (layer2): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (downsample): Sequential(
            (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): BasicBlock(
          (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (layer3): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (downsample): Sequential(
            (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): BasicBlock(
          (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (layer4): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (downsample): Sequential(
            (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): BasicBlock(
          (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
    )
    (decoder): UnetDecoder(
      (center): Identity()
      (blocks): ModuleList(
        (0): DecoderBlock(
          (conv1): Conv2dReLU(
            (0): Conv2d(768, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention1): Attention(
            (attention): Identity()
          )
          (conv2): Conv2dReLU(
            (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention2): Attention(
            (attention): Identity()
          )
        )
        (1): DecoderBlock(
          (conv1): Conv2dReLU(
            (0): Conv2d(384, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention1): Attention(
            (attention): Identity()
          )
          (conv2): Conv2dReLU(
            (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention2): Attention(
            (attention): Identity()
          )
        )
        (2): DecoderBlock(
          (conv1): Conv2dReLU(
            (0): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention1): Attention(
            (attention): Identity()
          )
          (conv2): Conv2dReLU(
            (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention2): Attention(
            (attention): Identity()
          )
        )
        (3): DecoderBlock(
          (conv1): Conv2dReLU(
            (0): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention1): Attention(
            (attention): Identity()
          )
          (conv2): Conv2dReLU(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention2): Attention(
            (attention): Identity()
          )
        )
        (4): DecoderBlock(
          (conv1): Conv2dReLU(
            (0): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention1): Attention(
            (attention): Identity()
          )
          (conv2): Conv2dReLU(
            (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention2): Attention(
            (attention): Identity()
          )
        )
      )
    )
    (segmentation_head): SegmentationHead(
      (0): Conv2d(16, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): Identity()
      (2): Activation(
        (activation): Tanh()
      )
    )
  )
  (discriminator_A): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
  )
  (discriminator_B): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
  )
)
In [ ]:
%%time
model, optimizer_d, optimizer_g, plots = learning_loop(
    model = model,
    optimizer_g = optimizer_g,
    g_iters_per_epoch = 1,
    optimizer_d = optimizer_d,
    d_iters_per_epoch = 1,
    train_loader_a = dataloaders.train_a,
    train_loader_b = dataloaders.train_b,
    val_loader_a = dataloaders.test_a,
    val_loader_b = dataloaders.test_b,
    criterion_d = criterion_d,
    criterion_g = criterion_g,
    scheduler_g = scheduler_g,
    scheduler_d = scheduler_d,
    de_norm_a = de_normalize_a,
    de_norm_b = de_normalize_b,
    epochs = 100,
    min_lr = 1e-6,
    val_every = 1,
    draw_every = 1,
    chkp_folder = "./chkp",
    model_name = "cycle_gan",
    images_per_validation=3,
    plots=None,
    starting_epoch=0,
)
#1/100:
train discriminators (1 times)
  0%|          | 0/8 [00:00<?, ?it/s]
In [26]:
model
Out[26]:
CycleGAN(
  (generator_A2B): Unet(
    (encoder): ResNetEncoder(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (layer2): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (downsample): Sequential(
            (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): BasicBlock(
          (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (layer3): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (downsample): Sequential(
            (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): BasicBlock(
          (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (layer4): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (downsample): Sequential(
            (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): BasicBlock(
          (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
    )
    (decoder): UnetDecoder(
      (center): Identity()
      (blocks): ModuleList(
        (0): DecoderBlock(
          (conv1): Conv2dReLU(
            (0): Conv2d(768, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention1): Attention(
            (attention): Identity()
          )
          (conv2): Conv2dReLU(
            (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention2): Attention(
            (attention): Identity()
          )
        )
        (1): DecoderBlock(
          (conv1): Conv2dReLU(
            (0): Conv2d(384, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention1): Attention(
            (attention): Identity()
          )
          (conv2): Conv2dReLU(
            (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention2): Attention(
            (attention): Identity()
          )
        )
        (2): DecoderBlock(
          (conv1): Conv2dReLU(
            (0): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention1): Attention(
            (attention): Identity()
          )
          (conv2): Conv2dReLU(
            (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention2): Attention(
            (attention): Identity()
          )
        )
        (3): DecoderBlock(
          (conv1): Conv2dReLU(
            (0): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention1): Attention(
            (attention): Identity()
          )
          (conv2): Conv2dReLU(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention2): Attention(
            (attention): Identity()
          )
        )
        (4): DecoderBlock(
          (conv1): Conv2dReLU(
            (0): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention1): Attention(
            (attention): Identity()
          )
          (conv2): Conv2dReLU(
            (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention2): Attention(
            (attention): Identity()
          )
        )
      )
    )
    (segmentation_head): SegmentationHead(
      (0): Conv2d(16, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): Identity()
      (2): Activation(
        (activation): Tanh()
      )
    )
  )
  (generator_B2A): Unet(
    (encoder): ResNetEncoder(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (layer2): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (downsample): Sequential(
            (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): BasicBlock(
          (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (layer3): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (downsample): Sequential(
            (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): BasicBlock(
          (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (layer4): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (downsample): Sequential(
            (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
            (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          )
        )
        (1): BasicBlock(
          (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
    )
    (decoder): UnetDecoder(
      (center): Identity()
      (blocks): ModuleList(
        (0): DecoderBlock(
          (conv1): Conv2dReLU(
            (0): Conv2d(768, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention1): Attention(
            (attention): Identity()
          )
          (conv2): Conv2dReLU(
            (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention2): Attention(
            (attention): Identity()
          )
        )
        (1): DecoderBlock(
          (conv1): Conv2dReLU(
            (0): Conv2d(384, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention1): Attention(
            (attention): Identity()
          )
          (conv2): Conv2dReLU(
            (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention2): Attention(
            (attention): Identity()
          )
        )
        (2): DecoderBlock(
          (conv1): Conv2dReLU(
            (0): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention1): Attention(
            (attention): Identity()
          )
          (conv2): Conv2dReLU(
            (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention2): Attention(
            (attention): Identity()
          )
        )
        (3): DecoderBlock(
          (conv1): Conv2dReLU(
            (0): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention1): Attention(
            (attention): Identity()
          )
          (conv2): Conv2dReLU(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention2): Attention(
            (attention): Identity()
          )
        )
        (4): DecoderBlock(
          (conv1): Conv2dReLU(
            (0): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention1): Attention(
            (attention): Identity()
          )
          (conv2): Conv2dReLU(
            (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (attention2): Attention(
            (attention): Identity()
          )
        )
      )
    )
    (segmentation_head): SegmentationHead(
      (0): Conv2d(16, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): Identity()
      (2): Activation(
        (activation): Tanh()
      )
    )
  )
  (discriminator_A): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
  )
  (discriminator_B): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
  )
)
In [27]:
%%time
model, optimizer_d, optimizer_g, plots = learning_loop(
    model = model,
    optimizer_g = optimizer_g,
    g_iters_per_epoch = 1,
    optimizer_d = optimizer_d,
    d_iters_per_epoch = 1,
    train_loader_a = dataloaders.train_a,
    train_loader_b = dataloaders.train_b,
    val_loader_a = dataloaders.test_a,
    val_loader_b = dataloaders.test_b,
    criterion_d = criterion_d,
    criterion_g = criterion_g,
    scheduler_g = scheduler_g,
    scheduler_d = scheduler_d,
    de_norm_a = de_normalize_a,
    de_norm_b = de_normalize_b,
    epochs = 100,
    min_lr = 1e-6,
    val_every = 1,
    draw_every = 1,
    chkp_folder = "./chkp",
    model_name = "cycle_gan",
    images_per_validation=3,
    plots=None,
    starting_epoch=0,
)
#1/100:
train discriminators (1 times)
  0%|          | 0/8 [00:00<?, ?it/s]
---------------------------------------------------------------------------
OutOfMemoryError                          Traceback (most recent call last)
File <timed exec>:1

Cell In[20], line 82, in learning_loop(model, optimizer_g, g_iters_per_epoch, optimizer_d, d_iters_per_epoch, train_loader_a, train_loader_b, val_loader_a, val_loader_b, criterion_d, criterion_g, de_norm_a, de_norm_b, scheduler_d, scheduler_g, min_lr, epochs, val_every, draw_every, model_name, chkp_folder, images_per_validation, plots, starting_epoch)
     80 loss_d = []
     81 for _ in range(d_iters_per_epoch):
---> 82     model, optimizer_d, loss = train_discriminators(model, optimizer_d, train_loader_a, train_loader_b, criterion_d)
     83     loss_d.append(loss)
     84 plots['train D'].extend(loss_d)

Cell In[16], line 16, in train_discriminators(model, opt_d, loader_a, loader_b, criterion_d)
     11 imgs_b = next(iter_b).to(device)
     13 opt_d.zero_grad()
     15 same_a, same_b, fake_a, fake_b, recovered_a, recovered_b, \
---> 16         is_fake_a_pred, is_fake_b_pred, is_real_a_pred, is_real_b_pred = model(imgs_a, imgs_b)
     18 loss = criterion_d(is_real_a_pred, is_fake_a_pred, is_real_b_pred, is_fake_b_pred)
     20 loss.backward()

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

Cell In[10], line 47, in CycleGAN.forward(self, a, b)
     46 def forward(self, a, b):
---> 47     same_b = self.generator_A2B(b)
     48     same_a = self.generator_B2A(a)
     50     # generate fake images

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~/.local/lib/python3.10/site-packages/segmentation_models_pytorch/base/model.py:48, in SegmentationModel.forward(self, x)
     45 if not torch.jit.is_tracing() or self.requires_divisible_input_shape:
     46     self.check_input_shape(x)
---> 48 features = self.encoder(x)
     49 decoder_output = self.decoder(*features)
     51 masks = self.segmentation_head(decoder_output)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~/.local/lib/python3.10/site-packages/segmentation_models_pytorch/encoders/resnet.py:63, in ResNetEncoder.forward(self, x)
     61 features = []
     62 for i in range(self._depth + 1):
---> 63     x = stages[i](x)
     64     features.append(x)
     66 return features

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/container.py:217, in Sequential.forward(self, input)
    215 def forward(self, input):
    216     for module in self:
--> 217         input = module(input)
    218     return input

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/conv.py:460, in Conv2d.forward(self, input)
    459 def forward(self, input: Tensor) -> Tensor:
--> 460     return self._conv_forward(input, self.weight, self.bias)

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/conv.py:456, in Conv2d._conv_forward(self, input, weight, bias)
    452 if self.padding_mode != 'zeros':
    453     return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
    454                     weight, bias, self.stride,
    455                     _pair(0), self.dilation, self.groups)
--> 456 return F.conv2d(input, weight, bias, self.stride,
    457                 self.padding, self.dilation, self.groups)

OutOfMemoryError: CUDA out of memory. Tried to allocate 200.00 MiB. GPU 
In [28]:
img_a = ds.test_a[1].to(device).unsqueeze(0)

plt.subplots(1, 2, figsize=(20, 10))

plt.subplot(121)
plt.imshow(de_normalize_a(img_a[0]))

plt.subplot(122)
plt.imshow(de_normalize_b(model.generators["a_to_b"](img_a)[0]))

plt.show()
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[28], line 9
      6 plt.imshow(de_normalize_a(img_a[0]))
      8 plt.subplot(122)
----> 9 plt.imshow(de_normalize_b(model.generators["a_to_b"](img_a)[0]))
     11 plt.show()

File ~/.local/lib/python3.10/site-packages/torch/nn/modules/module.py:1709, in Module.__getattr__(self, name)
   1707     if name in modules:
   1708         return modules[name]
-> 1709 raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

AttributeError: 'CycleGAN' object has no attribute 'generators'
No description has been provided for this image
In [ ]:
img_a = ds.test_a[5].to(device).unsqueeze(0)

plt.subplots(1, 2, figsize=(20, 10))

plt.subplot(121)
plt.imshow(de_normalize_a(img_a[0]))

plt.subplot(122)
plt.imshow(de_normalize_b(model.generators["a_to_b"](img_a)[0]))

plt.show()
In [ ]:
img_ind = 1
img_a = ds.test_a[img_ind].to(device).unsqueeze(0)
fake_b = ...(img_a)...

plt.subplots(1, 2, figsize=(20, 10))

plt.subplot(121)
plt.imshow(de_normalize_a(img_a[0]))

plt.subplot(122)
plt.imshow(de_normalize_b(fake_b))

plt.show()
In [ ]:
img_ind = 0
img_b = ds.test_b[img_ind].to(device).unsqueeze(0)
fake_a = ...(img_b)...

plt.subplots(1, 2, figsize=(20, 10))

plt.subplot(121)
plt.imshow(de_normalize_b(img_b[0]))

plt.subplot(122)
plt.imshow(de_normalize_a(fake_a))

plt.show()

5. Свои данные (3 балла + бонусы)¶

В этой части необходимо выполнить следующее:

  1. Создание датасетов

    • Соберите один или пару небольших датасета (минимум 100 примеров, но чем больше, тем лучше).
    • Дайте волю фантазии, но избегайте неадекватного контента, объема данных или данных за рамками норм.
  2. Обучение CycleGAN

    • Обучите модель CycleGAN для преобразования между вашим датасетом (домен A) и другим датасетом (домен B).
    • Второй датасет может быть вашим собственным или любым существующим
  3. Требования к сдаче

    • Архив с датасетами:
      Приложите ссылку на заархивированные данные или загрузите сам архив с данными, использованные для обучения.
    • Jupyter Notebook:
      • Визуализируйте примеры из доменов A и B с кратким описанием идеи (например, "преобразование эскизов в цветные рисунки").
      • Добавьте код обучения модели (архитектура, гиперпараметры, функция потерь).
      • Покажите результаты работы модели (минимум по 5 примеров преобразований A→B и B→A).
    • Hugging Face Space + Streamlit:
      • Разработайте интерактивное приложение с использованием Streamlit, которое позволяет:
        • Загружать изображения из доменов A и B.
        • Отображать результаты преобразований (A→B и B→A) в реальном времени.
      • Выложите приложение в Hugging Face Space и приложите ссылку на него.
      • Убедитесь, что модель интегрирована в приложение, а проверяющие могут самостоятельно тестировать её через интерфейс.
  4. Критерии оценки

    • Дополнительные баллы начисляются за:
      • Креативные и качественные датасеты.
      • Высокое качество преобразований (четкость, сохранение структуры, отсутствие артефактов).
      • Удобный и наглядный интерфейс Streamlit-приложения.

Примеры идей для датасетов:

  • Эскизы → Реалистичные изображения.
  • Дневные фото → Ночные фото.
  • Картины в стиле импрессионизма → фотореализм.

Важно:

  • Проверяющий будет оценивать работу через ваше Streamlit-приложение. Убедитесь, что инференс работает стабильно.
  • Если модель слишком велика для деплоя, используйте оптимизацию.

Удачи! 🚀

In [ ]: